iT邦幫忙

2022 iThome 鐵人賽

DAY 8
0
AI & Data

JAX 好好玩系列 第 8

JAX 好好玩 (8) : JAX.NUMPY (4) : 用了才知道它的快

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

極速實驗的結果

現在我們要來看看 jax.numpy 執行起來比傳統的 Numpy 快多少?老頭分別做了以下 6 種試驗:

  • Numpy 在 CPU 上計算 ndarray 資料
  • jax.numpy 在 CPU 上計算 ndarray 資料
  • jax.numpy 在 CPU 上計算 DeviceArray 資料
  • jax.numpy 在 GPU 上計算 ndarray 資料
  • jax.numpy 在 GPU 上計算 DeviceArray 資料

所使用的 Numpy 資料和 DeviceArray 資料如下:

shape = (10000, 10000)
 
# numpy data
np.random.seed(0)
x_np  = np.random.normal(size=shape).astype(np.float32)
 
# jax data
key = jax.random.PRNGKey(0)
x_jax = jax.random.normal(key, shape, dtype=jnp.float32)

所用的計算則是 dot()。

np.dot()
jnp.dot()

這些試驗是在 colab 上執行,colab 分配給我的 CPU 及 GPU 型號分別是:

  1. CPU: colab 配置兩個 processor 的虛擬機 (VM), 規格都是:

processor : 0 (or 1)
vendor_id : GenuineIntel
cpu family : 6
model : 79
model name : Intel(R) Xeon(R) CPU @ 2.20GHz
stepping : 0
microcode : 0x1
cpu MHz : 2199.998
cache size : 56320 KB

  1. GPU: colab 配置一個 GPU, 規格是:

GPU 0: Tesla T4 (UUID: GPU-d515ed62-b83e-595b-a63d-1c11b56a2197)

最終的結果是這樣的:

https://ithelp.ithome.com.tw/upload/images/20220919/20129616Axup0GjcmG.png

首先比較試驗編號 E 和 A,它們代表了最典型的 JAX 運算及 Numpy 運算,JAX 的執行速度足足比 numpy 快了 28 倍!效率的提升相當驚人。

另外,試驗編號 D 比編號 E 稍微多了一點時間,依據 JAX 文件的說法,這額外的時間是花在將 ndarray 轉成 DeviceArray 再移至 GPU 上面。

而從編號 B 和 C 的結果,我們可以了解到 jax.numpy 在純 CPU 的環境下,執行的效率比較不好,比典型的 Numpy 多了兩倍時間,這是老頭之前所沒有預料到的。未來得多做一些實驗來觀察 jax.numpy 在 CPU 下的表現。

接下來分別說明這些實驗的執行過程。

On CPU

首先我們要把 colab 執行環境切到 CPU only,以執行實驗 A,B 和 C:

編輯 → 筆記本設定 → 硬體加速器 : 選取 None
選好後記得按「儲存」

https://ithelp.ithome.com.tw/upload/images/20220919/20129616mXIf0mr06n.png

把該有的程式庫 import 進來:

import jax
import jax.numpy as jnp
import numpy as np

宣告 10000 x 10000 的亂數矩陣:

shape = (10000, 10000)

# numpy data
np.random.seed(0)
x_np  = np.random.normal(size=shape).astype(np.float32)

# jax data
key = jax.random.PRNGKey(0)
x_jax = jax.random.normal(key, shape, dtype=jnp.float32)

可以執行測試了:

# test A
%timeit np.dot(x_np, x_np.T)  

# test B
%timeit jnp.dot(x_np, x_np.T).block_until_ready()  

# test C
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() 

眼尖的讀者可能發覺在執行 jax.numpy API dot() 時,老頭附帶了 “block_until_ready()”,原因是 JAX 使用了「asynchronous dispatch」[8.1],在執行 dot() 等 JAX 運算時,JAX 會非同步的執行運算,而儘早的將控制權交還給 Python,並這會造成 %timeit 所得到的時間太短,而不能顯示出 dot 運算真正的執行時間。“block_until_ready()” 即是告訴 JAX,等計算結果出來後,才釋出控制權。這樣, %timeit 所得到的時間才是我們要的。

On GPU

其次我們要在 colab 上把執行環境切回到使用 GPU 做為硬體加速器,再來做實驗 D 和 E:

編輯 → 筆記本設定 → 硬體加速器 : 選取 GPU
選好後記得按「儲存」

https://ithelp.ithome.com.tw/upload/images/20220919/20129616824ba5CSey.png

程式庫的 import 和宣告亂數矩陣,和 On CPU 完全相同。

執行另外兩個測試:

# test D
%timeit jnp.dot(x_np, x_np.T).block_until_ready()

# test E
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()

讀者可以在老頭提供的 colab 筆記本上,自己做一下實驗,看看你自己做出來的結果如何。

註:

[8.1] 有關 asynchronous dispatch,可參考 JAX 文件 Asynchronous dispatch


上一篇
JAX 好好玩 (7) : JAX.NUMPY (3) : 再探 JAX PRNG
下一篇
JAX 好好玩 (9) : JAX.NUMPY (5) : DeviceArray 初探
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言